package com.hcloud.gateway.filter;

import com.hcloud.common.redis.util.RedisUtil;
import io.netty.buffer.ByteBufAllocator;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.net.URI;
import java.nio.CharBuffer;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Gateway filter factory that validates the image verification code (captcha)
 * submitted with a login request before the request is routed downstream.
 *
 * <p>Only requests whose path contains {@link #OAUTH_TOKEN_URL} are inspected;
 * every other request is passed through untouched. For login requests the
 * form body is read, the {@code code} / {@code randomUuid} parameters are
 * checked against the value stored via {@code RedisUtil}, and the request is
 * either rejected with {@code 428 Precondition Required} or forwarded with a
 * replayable copy of the body.
 *
 * @author hepangui
 * @since 2018/11/12
 */
@Slf4j
@Component
public class ImageCodeFilter extends AbstractGatewayFilterFactory<Object> {
    /** Path fragment that identifies a login (token) request. */
    public static final String OAUTH_TOKEN_URL = "/oauth/token";
    /** Prefix of the key under which the expected code is looked up. */
    private static final String CODE_KEY_PREFIX = "CODE_";

    /**
     * Creates the captcha-validating filter.
     *
     * @param config filter configuration; not used by this filter
     * @return the filter instance
     */
    @Override
    public GatewayFilter apply(Object config) {
        return (exchange, chain) -> validateAndForward(exchange, chain);
    }

    /**
     * Validates the verification code of a login request and forwards the
     * request when the code is correct.
     *
     * @param exchange the current exchange
     * @param chain    the remaining filter chain
     * @return completion signal of the request handling
     */
    private Mono<Void> validateAndForward(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest serverHttpRequest = exchange.getRequest();

        // Not a login request: continue down the chain unchanged.
        String path = serverHttpRequest.getURI().getPath();
        if (path == null || !path.contains(OAUTH_TOKEN_URL)) {
            return chain.filter(exchange);
        }

        // The body is consumed asynchronously; validation has to happen inside
        // the reactive pipeline, once all chunks of the body have arrived.
        return readBody(serverHttpRequest).flatMap(bodyBytes -> {
            // The body holds the user's credentials, so it is deliberately never logged.
            Map<String, String> paraMap = bodyToMap(new String(bodyBytes, StandardCharsets.UTF_8));

            if (!checkCode(paraMap.get("randomUuid"), paraMap.get("code"))) {
                log.debug("Verification code check failed for login request");
                ServerHttpResponse response = exchange.getResponse();
                response.setStatusCode(HttpStatus.PRECONDITION_REQUIRED);
                return response.setComplete();
            }

            // The original body has been consumed, so hand downstream filters a
            // request whose body is re-created from the cached bytes on every
            // subscription. Wrapping (instead of allocating a pooled buffer up
            // front) avoids a leak when the body is never subscribed.
            ServerHttpRequest replayableRequest = new ServerHttpRequestDecorator(serverHttpRequest) {
                @Override
                public Flux<DataBuffer> getBody() {
                    return Flux.defer(() ->
                            Flux.just(exchange.getResponse().bufferFactory().wrap(bodyBytes)));
                }
            };
            return chain.filter(exchange.mutate().request(replayableRequest).build());
        });
    }

    /**
     * Checks the submitted code against the stored one. The stored entry is
     * deleted on every lookup, so a code can be tried only once.
     *
     * @param key  the {@code randomUuid} identifying the stored code
     * @param code the code submitted by the client
     * @return {@code true} when both values are present and the code matches
     *         the stored value ignoring case
     */
    private boolean checkCode(String key, String code) {
        if (key == null || key.isEmpty() || code == null || code.isEmpty()) {
            return false;
        }
        // NOTE(review): RedisUtil looks like a synchronous call made from the
        // reactive pipeline - confirm it does not block the event loop.
        String redisKey = CODE_KEY_PREFIX + key;
        Object expected = RedisUtil.get(redisKey);
        RedisUtil.del(redisKey); // one-time use: remove the code after reading it
        // String.valueOf avoids a ClassCastException for non-String stored values.
        return expected != null && code.equalsIgnoreCase(String.valueOf(expected));
    }

    /**
     * Reads the complete request body as bytes.
     *
     * @param request the request whose body is read (and thereby consumed)
     * @return a {@code Mono} emitting all body bytes, or an empty array when
     *         the request has no body
     */
    private Mono<byte[]> readBody(ServerHttpRequest request) {
        return DataBufferUtils.join(request.getBody())
                .map(buffer -> {
                    try {
                        byte[] bytes = new byte[buffer.readableByteCount()];
                        buffer.read(bytes);
                        return bytes;
                    } finally {
                        // The joined buffer may be pooled; always give it back.
                        DataBufferUtils.release(buffer);
                    }
                })
                // join() completes empty for a body-less request; still run validation.
                .defaultIfEmpty(new byte[0]);
    }

    /**
     * Splits a form body ({@code a=1&b=2}) into a parameter map. A parameter
     * without {@code =} is mapped to the empty string; only the first
     * {@code =} separates name and value. Names and values are not URL-decoded.
     *
     * @param body the raw form body, may be {@code null}
     * @return the parsed parameters, never {@code null}
     */
    private Map<String, String> bodyToMap(String body) {
        Map<String, String> map = new HashMap<>();
        if (body == null || body.isEmpty()) {
            return map;
        }
        for (String pair : body.split("&")) {
            int separator = pair.indexOf('=');
            if (separator < 0) {
                map.put(pair, "");
            } else {
                // substring keeps empty values ("code=") and values containing '='.
                map.put(pair.substring(0, separator), pair.substring(separator + 1));
            }
        }
        return map;
    }

}
